<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
    <title>WebGPU Compute Shaders - Histogram, Video</title>
    <style>
      @import url(resources/webgpu-lesson.css);
html, body {
  margin: 0;       /* remove the default margin          */
  height: 100%;    /* make the html,body fill the page   */
}
canvas {
  display: block;  /* make the canvas act like a block   */
  width: 100%;     /* make the canvas fill its container */
  height: 100%;
}
#start {
  position: fixed;
  left: 0;
  top: 0;
  width: 100%;
  height: 100%;
  display: flex;
  justify-content: center;
  align-items: center;
}
#start>div {
  font-size: 200px;
  cursor: pointer;
}
#info {
  position: absolute;
  top: 0;
  left: 0;
  margin: 0;
  padding: 0.5em;
  background-color: rgba(0, 0, 0, 0.8);
  color: white;
}
    </style>
  </head>
  <body>
    <canvas></canvas>
    <pre id="info">

    </pre>
    <div id="start">
      <div>▶️</div>
    </div>
  </body>
  <script type="module">
import TimingHelper from './resources/js/timing-helper.js';
import NonNegativeRollingAverage from './resources/js/non-negative-rolling-average.js';
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-utils.html#wgpu-matrix
import { mat4 } from '../3rdparty/wgpu-matrix.module.js';

// Builds an array with `count` elements where element n holds `fn(n)`.
const range = (count, fn) => {
  const result = new Array(count).fill(0);
  for (let ndx = 0; ndx < result.length; ++ndx) {
    result[ndx] = fn(ndx);
  }
  return result;
};

// Rolling averages for the four statistics shown in the #info element:
// frames per second, JavaScript time, and the two GPU pass timings.
const [fpsAverage, jsAverage, histogramAverage, renderAverage] =
    Array.from({ length: 4 }, () => new NonNegativeRollingAverage());

async function main() {
  const adapter = await navigator.gpu?.requestAdapter();
  const canTimestamp = adapter?.features.has('timestamp-query');
  const device = await adapter?.requestDevice({
    requiredFeatures: [
      ...(canTimestamp ? ['timestamp-query'] : []),
    ],
  });
  if (!device) {
    fail('need a browser that supports WebGPU');
    return;
  }

  const timingHelper = new TimingHelper(device);
  const timingHelper2 = new TimingHelper(device);

  // Get a WebGPU context from the canvas and configure it
  const canvas = document.querySelector('canvas');
  const context = canvas.getContext('webgpu');
  const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
  context.configure({
    device,
    format: presentationFormat,
  });

  // Chunk (workgroup) dimensions shared by JavaScript and the WGSL shaders.
  const k = {
    chunkWidth: 256,
    chunkHeight: 1,
  };
  const chunkSize = k.chunkWidth * k.chunkHeight;
  // Turn every entry of `k` into a WGSL `const` declaration, one per line,
  // so it can be pasted at the top of a shader.
  const sharedConstants = Object.keys(k)
      .map(name => `const ${name} = ${k[name]};`)
      .join('\n');

  // Compute shader that builds one histogram per chunk of the video frame.
  // Each workgroup covers a chunkWidth x chunkHeight tile of pixels. Every
  // invocation that lands inside the texture loads its pixel, replaces the
  // 4th channel with the pixel's luminance and atomically increments one
  // workgroup-shared bin per channel (r, g, b, luminance). After the barrier
  // each invocation copies one bin (4 counters) out to `chunks[chunk]`.
  // The number of bins equals chunkSize, so there is one invocation per bin.
  const histogramChunkModule = device.createShaderModule({
    label: 'histogram chunk shader',
    code: /* wgsl */ `
      ${sharedConstants}
      const chunkSize = chunkWidth * chunkHeight;
      var<workgroup> bins: array<array<atomic<u32>, 4>, chunkSize>;
      @group(0) @binding(0) var<storage, read_write> chunks: array<array<vec4u, chunkSize>>;
      @group(0) @binding(1) var ourTexture: texture_external;

      const kSRGBLuminanceFactors = vec3f(0.2126, 0.7152, 0.0722);
      fn srgbLuminance(color: vec3f) -> f32 {
        return saturate(dot(color, kSRGBLuminanceFactors));
      }

      @compute @workgroup_size(chunkWidth, chunkHeight, 1)
      fn cs(
        @builtin(workgroup_id) workgroup_id: vec3u,
        @builtin(local_invocation_id) local_invocation_id: vec3u,
      ) {
        let size = textureDimensions(ourTexture);
        let position = workgroup_id.xy * vec2u(chunkWidth, chunkHeight) + 
                       local_invocation_id.xy;
        if (all(position < size)) {
          let numBins = f32(chunkSize);
          let lastBinIndex = u32(numBins - 1);
          var channels = textureLoad(ourTexture, position);
          channels.w = srgbLuminance(channels.rgb);
          for (var ch = 0; ch < 4; ch++) {
            let v = channels[ch];
            let bin = min(u32(v * numBins), lastBinIndex);
            atomicAdd(&bins[bin][ch], 1u);
          }
        }

        workgroupBarrier();

        let chunksAcross = (size.x + chunkWidth - 1) / chunkWidth;
        let chunk = workgroup_id.y * chunksAcross + workgroup_id.x;
        let bin = local_invocation_id.y * chunkWidth + local_invocation_id.x;

        chunks[chunk][bin] = vec4u(
          atomicLoad(&bins[bin][0]),
          atomicLoad(&bins[bin][1]),
          atomicLoad(&bins[bin][2]),
          atomicLoad(&bins[bin][3]),
        );
      }
    `,
  });

  // Compute shader for one step of the chunk reduction. Workgroup w adds
  // chunk (w * stride * 2 + stride) into chunk (w * stride * 2); each of the
  // chunkSize invocations sums one bin. `stride` comes from the uniform
  // buffer, so a different bind group is used for every reduction step.
  const chunkSumModule = device.createShaderModule({
    label: 'chunk sum shader',
    code: /* wgsl */ `
      ${sharedConstants}
      const chunkSize = chunkWidth * chunkHeight;

      struct Uniforms {
        stride: u32,
      };

      @group(0) @binding(0) var<storage, read_write> chunks: array<array<vec4u, chunkSize>>;
      @group(0) @binding(1) var<uniform> uni: Uniforms;

      @compute @workgroup_size(chunkSize, 1, 1) fn cs(
        @builtin(local_invocation_id) local_invocation_id: vec3u,
        @builtin(workgroup_id) workgroup_id: vec3u,
      ) {
        let chunk0 = workgroup_id.x * uni.stride * 2;
        let chunk1 = chunk0 + uni.stride;

        let sum = chunks[chunk0][local_invocation_id.x] +
                  chunks[chunk1][local_invocation_id.x];
        chunks[chunk0][local_invocation_id.x] = sum;
      }
    `,
  });

  // Compute shader, run as a single invocation, that derives the per-channel
  // scale used when drawing the histogram. It finds the largest count in any
  // bin for each channel and stores
  //   scale = max(1 / maxCount, 0.2 * numBins / numPixels)
  // where numPixels is the width * height of the video texture.
  const scaleModule = device.createShaderModule({
    label: 'histogram scale shader',
    code: /* wgsl */ `
      @group(0) @binding(0) var<storage, read> bins: array<vec4u>;
      @group(0) @binding(1) var<storage, read_write> scale: vec4f;
      @group(0) @binding(2) var ourTexture: texture_external;

      @compute @workgroup_size(1, 1, 1) fn cs() {
        let size = textureDimensions(ourTexture);
        let numEntries = f32(size.x * size.y);
        var m = vec4u(0);
        let numBins = arrayLength(&bins);
        for (var i = 0u ; i < numBins; i++) {
          m = max(m, bins[i]);
        }
        scale = max(1.0 / vec4f(m), vec4f(0.2 * f32(numBins) / numEntries));
      }
    `,
  });

  // Shaders that draw a histogram into a unit quad positioned by `uni.matrix`.
  // The vertex shader emits two triangles covering (0,0)-(1,1) and passes the
  // quad coordinate through as `texcoord`. The fragment shader picks the bin
  // under texcoord.x, scales its 4 channel counts by `scale`, and for every
  // channel whose bar is taller than texcoord.y adds that channel's
  // `channelMult` value; the sum is the index into `uni.colors`.
  const drawHistogramModule = device.createShaderModule({
    label: 'draw histogram shader',
    code: /* wgsl */ `
      struct OurVertexShaderOutput {
        @builtin(position) position: vec4f,
        @location(0) texcoord: vec2f,
      };

      struct Uniforms {
        matrix: mat4x4f,
        colors: array<vec4f, 16>,
        channelMult: vec4u,
      };

      @group(0) @binding(0) var<storage, read> bins: array<vec4u>;
      @group(0) @binding(1) var<uniform> uni: Uniforms;
      // The fragment stage only reads the scale, so bind it read-only rather
      // than requiring a writable storage buffer in the fragment stage.
      @group(0) @binding(2) var<storage, read> scale: vec4f;

      @vertex fn vs(
        @builtin(vertex_index) vertexIndex : u32
      ) -> OurVertexShaderOutput {
        let pos = array(
          // 1st triangle
          vec2f( 0.0,  0.0),  // center
          vec2f( 1.0,  0.0),  // right, center
          vec2f( 0.0,  1.0),  // center, top

          // 2nd triangle
          vec2f( 0.0,  1.0),  // center, top
          vec2f( 1.0,  0.0),  // right, center
          vec2f( 1.0,  1.0),  // right, top
        );

        var vsOutput: OurVertexShaderOutput;
        let xy = pos[vertexIndex];
        vsOutput.position = uni.matrix * vec4f(xy, 0.0, 1.0);
        vsOutput.texcoord = xy;
        return vsOutput;
      }

      @fragment fn fs(fsInput: OurVertexShaderOutput) -> @location(0) vec4f {
        let numBins = arrayLength(&bins);
        let lastBinIndex = u32(numBins - 1);
        let bin = clamp(
            u32(fsInput.texcoord.x * f32(numBins)),
            0,
            lastBinIndex);
        let heights = vec4f(bins[bin]) * scale;
        let bits = heights > vec4f(fsInput.texcoord.y);
        let ndx = dot(select(vec4u(0), uni.channelMult, bits), vec4u(1));
        return uni.colors[ndx];
      }
    `,
  });

  // Shaders that draw the video frame. The vertex shader emits a unit quad
  // (two triangles covering (0,0)-(1,1)) transformed by `uni.matrix` and
  // forwards the quad coordinate as the texture coordinate. The fragment
  // shader samples the external (video) texture at that coordinate.
  const videoModule = device.createShaderModule({
    label: 'our hardcoded external textured quad shaders',
    code: /* wgsl */ `
      struct OurVertexShaderOutput {
        @builtin(position) position: vec4f,
        @location(0) texcoord: vec2f,
      };

      struct Uniforms {
        matrix: mat4x4f,
      };

      @group(0) @binding(2) var<uniform> uni: Uniforms;

      @vertex fn vs(
        @builtin(vertex_index) vertexIndex : u32
      ) -> OurVertexShaderOutput {
        let pos = array(

          vec2f( 0.0,  0.0),  // center
          vec2f( 1.0,  0.0),  // right, center
          vec2f( 0.0,  1.0),  // center, top

          // 2st triangle
          vec2f( 0.0,  1.0),  // center, top
          vec2f( 1.0,  0.0),  // right, center
          vec2f( 1.0,  1.0),  // right, top
        );

        var vsOutput: OurVertexShaderOutput;
        let xy = pos[vertexIndex];
        vsOutput.position = uni.matrix * vec4f(xy, 0.0, 1.0);
        vsOutput.texcoord = xy;
        return vsOutput;
      }

      @group(0) @binding(0) var ourSampler: sampler;
      @group(0) @binding(1) var ourTexture: texture_external;

      @fragment fn fs(fsInput: OurVertexShaderOutput) -> @location(0) vec4f {
        return textureSampleBaseClampToEdge(
            ourTexture,
            ourSampler,
            fsInput.texcoord,
        );
      }
    `,
  });

  // All pipelines use `layout: 'auto'`, so their bind group layouts are
  // derived from the shader modules.

  // Compute pipeline for `module` with the given debug label.
  const makeComputePipeline = (label, module) => device.createComputePipeline({
    label,
    layout: 'auto',
    compute: { module },
  });

  // Render pipeline whose vertex and fragment stages both come from `module`
  // and which renders to a single target in the canvas format.
  const makeCanvasRenderPipeline = (label, module) => device.createRenderPipeline({
    label,
    layout: 'auto',
    vertex: { module },
    fragment: {
      module,
      targets: [{ format: presentationFormat }],
    },
  });

  const histogramChunkPipeline = makeComputePipeline('histogram', histogramChunkModule);
  const chunkSumPipeline = makeComputePipeline('chunk sum', chunkSumModule);
  const scalePipeline = makeComputePipeline('max', scaleModule);

  const drawHistogramPipeline = makeCanvasRenderPipeline('draw histogram', drawHistogramModule);
  const videoPipeline = makeCanvasRenderPipeline(
      'hardcoded video textured quad pipeline', videoModule);

  // Sampler used to draw the video; linear filtering when magnified or minified.
  const videoSampler = device.createSampler({
    magFilter: 'linear',
    minFilter: 'linear',
  });

  // The video shader's Uniforms struct holds a single mat4x4f, i.e. 16 floats
  // (64 bytes). The typed array length is in elements, not bytes, so it must
  // be 16 rather than 16 * 4; otherwise 4x the data is allocated and uploaded
  // every frame.
  const kMatrixOffset = 0;
  const kMatrixLength = 16;
  // create a typedarray to hold the values for the uniforms in JavaScript
  const videoUniformValues = new Float32Array(kMatrixLength);
  // create a buffer for the uniform values
  const videoUniformBuffer = device.createBuffer({
    label: 'uniforms for video',
    size: videoUniformValues.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });
  // View of the matrix inside the uniform values; written each frame.
  const videoMatrix = videoUniformValues.subarray(kMatrixOffset, kMatrixOffset + kMatrixLength);

  // Reused every frame; only the color attachment's view changes.
  const renderPassDescriptor = {
    label: 'our basic canvas renderPass',
    colorAttachments: [
      {
        // view: <- to be filled out when we render
        clearValue: [0.3, 0.3, 0.3, 1],
        loadOp: 'clear',
        storeOp: 'store',
      },
    ],
  };

  /**
   * Starts playback of `video` and waits until a frame is available.
   *
   * Resolves from `requestVideoFrameCallback` when the video supports it;
   * otherwise polls once per animation frame until `currentTime` is > 0.
   * Rejects if the video fires an `error` event or `play()` rejects.
   *
   * @param {HTMLVideoElement} video the video to start
   * @return {Promise} settles as described above
   */
  function startPlayingAndWaitForVideo(video) {
    return new Promise((resolve, reject) => {
      // Remember failures so the polling fallback stops rescheduling itself
      // once the promise has been rejected.
      let failed = false;
      const onFailure = (err) => {
        failed = true;
        reject(err);
      };
      video.addEventListener('error', onFailure, { once: true });
      if ('requestVideoFrameCallback' in video) {
        video.requestVideoFrameCallback(resolve);
      } else {
        const timeWatcher = () => {
          if (video.currentTime > 0) {
            resolve();
          } else if (!failed) {
            requestAnimationFrame(timeWatcher);
          }
        };
        timeWatcher();
      }
      video.play().catch(onFailure);
    });
  }

  /**
   * Waits for the first click anywhere in the window, hides the #start
   * overlay when it happens, then resolves.
   * @return {Promise<void>}
   */
  function waitForClick() {
    return new Promise(resolve => {
      const onFirstClick = () => {
        const startElem = document.querySelector('#start');
        startElem.style.display = 'none';
        resolve();
      };
      // `once` removes the listener after it has fired.
      window.addEventListener('click', onFirstClick, { once: true });
    });
  }

  // Create the (off-DOM) video element that feeds the histogram.
  // It is muted and loops.
  const video = document.createElement('video');
  video.muted = true;
  video.loop = true;
  video.preload = 'auto';
  video.src = 'resources/videos/pexels-kosmo-politeska-5750980 (1080p).mp4'; /* webgpufundamentals: url */
  //video.src = 'resources/videos/production_id_4166349 (540p).mp4'; /* webgpufundamentals: url */
  //video.src = 'resources/videos/production_id_5077580 (1080p).mp4'; /* webgpufundamentals: url */
  // Playback is only started after the user clicks; everything below that
  // depends on the video's dimensions runs after a frame is available.
  await waitForClick();
  await startPlayingAndWaitForVideo(video);

  // Clicking the canvas toggles between playing and paused.
  canvas.addEventListener('click', () => {
    if (video.paused) {
      video.play();
    } else {
      video.pause();
    }
  });

  // How many chunkWidth x chunkHeight chunks are needed to cover the video.
  const chunksAcross = Math.ceil(video.videoWidth / k.chunkWidth);
  const chunksDown = Math.ceil(video.videoHeight / k.chunkHeight);
  const numChunks = chunksAcross * chunksDown;

  // Storage for every chunk's histogram: chunkSize bins per chunk,
  // each bin being 4 values of 4 bytes.
  const chunksBuffer = device.createBuffer({
    size: numChunks * chunkSize * 4 * 4,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });

  // Holds the 4 per-channel scale values (4 x 4 bytes).
  const scaleBuffer = device.createBuffer({
    size: 4 * 4,
    usage: GPUBufferUsage.STORAGE,
  });

  // One bind group per reduction step; step n uses a stride of 2 ** n chunks.
  const numSteps = Math.ceil(Math.log2(numChunks));
  const sumBindGroups = Array.from({ length: numSteps }, (_, step) => {
    // The stride never changes, so write it while the buffer is mapped at creation.
    const strideBuffer = device.createBuffer({
      size: 4,
      usage: GPUBufferUsage.UNIFORM,
      mappedAtCreation: true,
    });
    new Uint32Array(strideBuffer.getMappedRange())[0] = 2 ** step;
    strideBuffer.unmap();

    return device.createBindGroup({
      layout: chunkSumPipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: chunksBuffer }},
        { binding: 1, resource: { buffer: strideBuffer }},
      ],
    });
  });

  // Two histograms are drawn: one combining channels 0, 1, 2 (r, g, b) and
  // one for channel 3 alone (luminance). Each gets its own uniform buffer
  // and bind group.
  const histogramDrawInfos = [
    [0, 1, 2],
    [3],
  ].map(channels => {
    // Uniform layout in 32-bit elements, mirroring `Uniforms` in the
    // draw histogram shader:
    //   [ 0, 16)  matrix: mat4x4f
    //   [16, 80)  colors: array<vec4f, 16>
    //   [80, 84)  channelMult: vec4u
    //   [84, 88)  4 spare elements not referenced by any view below
    const uniformValuesAsF32 = new Float32Array(88);
    const uniformValuesAsU32 = new Uint32Array(uniformValuesAsF32.buffer);
    const matrix = uniformValuesAsF32.subarray(0, 16);
    const colors = uniformValuesAsF32.subarray(16, 80);
    const channelMult = uniformValuesAsU32.subarray(80, 84);

    // A drawn channel n contributes 2 ** n to the color index computed in
    // the fragment shader; channels not drawn contribute 0.
    [0, 1, 2, 3].forEach(ch => {
      channelMult[ch] = channels.includes(ch) ? 2 ** ch : 0;
    });

    // Color for each possible index (sum of the channelMult values whose
    // bars cover the pixel).
    const palette = [
      [0, 0, 0, 1],
      [1, 0, 0, 1],
      [0, 1, 0, 1],
      [1, 1, 0, 1],
      [0, 0, 1, 1],
      [1, 0, 1, 1],
      [0, 1, 1, 1],
      [0.5, 0.5, 0.5, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
    ];
    palette.forEach((rgba, ndx) => colors.set(rgba, ndx * 4));

    const uniformBuffer = device.createBuffer({
      label: 'draw histogram uniform buffer',
      size: uniformValuesAsF32.byteLength,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });

    // Only the first chunk of chunksBuffer is bound as the bins to draw.
    const drawHistogramBindGroup = device.createBindGroup({
      layout: drawHistogramPipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: chunksBuffer, size: chunkSize * 4 * 4 }},
        { binding: 1, resource: { buffer: uniformBuffer } },
        { binding: 2, resource: { buffer: scaleBuffer }},
      ],
    });

    return { drawHistogramBindGroup, matrix, uniformBuffer, uniformValuesAsF32 };
  });

  // Element that displays the timing statistics.
  const infoElem = document.querySelector('#info');

  // Time of the previous frame in seconds; used to compute the frame delta.
  let then = 0;
  // Per-frame callback: computes the histogram of the current video frame
  // on the GPU, draws the video and both histograms to the canvas, updates
  // the statistics text and schedules the next frame.
  // `now` is the requestAnimationFrame timestamp in milliseconds.
  function render(now) {
    now *= 0.001;  // convert to seconds
    const deltaTime = now - then;
    then = now;

    // Used to measure how long the JavaScript in this function takes.
    const startTime = performance.now();

    // Import the video as an external texture; this is redone every frame,
    // as are the bind groups below that reference it.
    const texture = device.importExternalTexture({source: video});

    // make a bind group to build a histogram from this video texture
    const histogramBindGroup = device.createBindGroup({
      layout: histogramChunkPipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: chunksBuffer }},
        { binding: 1, resource: texture },
      ],
    });

    // Bind group for the scale pass; it only sees the first chunk of
    // chunksBuffer, which holds the totals once the reduction is done.
    const scaleBindGroup = device.createBindGroup({
      layout: scalePipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: chunksBuffer, size: chunkSize * 4 * 4 }},
        { binding: 1, resource: { buffer: scaleBuffer }},
        { binding: 2, resource: texture },
      ],
    });

    const encoder = device.createCommandEncoder({ label: 'histogram encoder' });
    {
      // Compute pass, timed by the first timing helper.
      const pass = timingHelper.beginComputePass(encoder);

      // create a histogram for each chunk
      pass.setPipeline(histogramChunkPipeline);
      pass.setBindGroup(0, histogramBindGroup);
      pass.dispatchWorkgroups(chunksAcross, chunksDown);

      // reduce the chunks
      // Each step dispatches floor(chunksLeft / 2) workgroups, each adding
      // one chunk into another, so that many fewer chunks remain afterwards.
      pass.setPipeline(chunkSumPipeline);
      let chunksLeft = numChunks;
      sumBindGroups.forEach(bindGroup => {
        pass.setBindGroup(0, bindGroup);
        const dispatchCount = Math.floor(chunksLeft / 2);
        chunksLeft -= dispatchCount;
        pass.dispatchWorkgroups(dispatchCount);
      });

      // Compute scales for the channels
      pass.setPipeline(scalePipeline);
      pass.setBindGroup(0, scaleBindGroup);
      pass.dispatchWorkgroups(1);

      pass.end();
    }

    // Draw to canvas
    {
      // Render into the canvas's current texture; timed by the second helper.
      const canvasTexture = context.getCurrentTexture().createView();
      renderPassDescriptor.colorAttachments[0].view = canvasTexture;
      const pass = timingHelper2.beginRenderPass(encoder, renderPassDescriptor);

      // Draw video
      const bindGroup = device.createBindGroup({
        layout: videoPipeline.getBindGroupLayout(0),
        entries: [
          { binding: 0, resource: videoSampler },
          { binding: 1, resource: texture },
          { binding: 2, resource: { buffer: videoUniformBuffer }},
        ],
      });

      // 'cover' canvas
      // Stretch whichever axis is needed so the video's aspect ratio is
      // kept while the whole canvas is filled.
      const canvasAspect = canvas.clientWidth / canvas.clientHeight;
      const videoAspect = video.videoWidth / video.videoHeight;
      const scale = canvasAspect > videoAspect
         ? [1, canvasAspect / videoAspect, 1]
         : [videoAspect / canvasAspect, 1, 1];

      // Build the matrix in place inside videoUniformValues (videoMatrix is
      // a view into it), then upload the uniform values.
      const matrix = mat4.identity(videoMatrix);
      mat4.scale(matrix, scale, matrix);
      mat4.translate(matrix, [-1, 1, 0], matrix);
      mat4.scale(matrix, [2, -2, 1], matrix);

      device.queue.writeBuffer(videoUniformBuffer, 0, videoUniformValues);

      pass.setPipeline(videoPipeline);
      pass.setBindGroup(0, bindGroup);
      pass.draw(6);  // call our vertex shader 6 times

      // Draw Histograms
      histogramDrawInfos.forEach(({
        matrix,
        uniformBuffer,
        uniformValuesAsF32,
        drawHistogramBindGroup,
      }, i) => {
        // Histogram i is translated by i along x, so the two histograms
        // sit side by side.
        mat4.identity(matrix);
        mat4.translate(matrix, [-0.95 + i, -1, 0], matrix);
        mat4.scale(matrix, [0.9, 0.5, 1], matrix);

        device.queue.writeBuffer(uniformBuffer, 0, uniformValuesAsF32);

        pass.setPipeline(drawHistogramPipeline);
        pass.setBindGroup(0, drawHistogramBindGroup);
        pass.draw(6);  // call our vertex shader 6 times
      });

      pass.end();
    }

    const commandBuffer = encoder.finish();
    device.queue.submit([commandBuffer]);

    // Collect GPU timings asynchronously. The results are divided by 1000
    // and displayed as microseconds.
    // NOTE(review): presumably getResult() reports nanoseconds - confirm in TimingHelper.
    timingHelper.getResult().then(gpuTime => {
      histogramAverage.addSample(gpuTime / 1000);
    });
    timingHelper2.getResult().then(gpuTime => {
      renderAverage.addSample(gpuTime / 1000);
    });

    const jsTime = performance.now() - startTime;

    fpsAverage.addSample(1 / deltaTime);
    jsAverage.addSample(jsTime);

    // GPU timings are shown as N/A when timestamp queries are unavailable.
    infoElem.textContent = `\
fps: ${fpsAverage.get().toFixed(1)}
js: ${jsAverage.get().toFixed(1)}ms
histogram gpu: ${canTimestamp ? `${histogramAverage.get().toFixed(1)}µs` : 'N/A'}
rendering gpu: ${canTimestamp ? `${renderAverage.get().toFixed(1)}µs` : 'N/A'}
`;

    requestAnimationFrame(render);
  }
  requestAnimationFrame(render);

  // Keep the canvas's drawing buffer the same size as its content box,
  // clamped to between 1 and the device's maximum 2D texture dimension.
  const observer = new ResizeObserver(entries => {
    const clampDimension = value =>
      Math.max(1, Math.min(value, device.limits.maxTextureDimension2D));
    entries.forEach(entry => {
      const { inlineSize, blockSize } = entry.contentBoxSize[0];
      const target = entry.target;
      target.width = clampDimension(inlineSize);
      target.height = clampDimension(blockSize);
    });
  });
  observer.observe(canvas);
}

/**
 * Reports a fatal problem to the user with an alert dialog.
 * @param {string} msg the message to show
 */
function fail(msg) {
  // eslint-disable-next-line no-alert
  globalThis.alert(msg);
}

// Start the demo. main() is async; its returned promise is not awaited here.
main();
  </script>
</html>
